from Network.network import Network
from Network.network_utils import get_inplace_acti
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

# Module-level bias flag.
# NOTE(review): not referenced by the code in this chunk (create_layers takes
# an explicit use_bias argument instead); it may be used elsewhere — verify
# before relying on it or removing it.
BIAS = True
def create_layers(inp_dim, out_dim, activation='none', norm=False, use_bias=True):
    """Build the module list for one pointwise (kernel size 1) Conv1d stage.

    The returned list is ``[GroupNorm (optional), Conv1d, activation]``:

    * an optional single-group ``nn.GroupNorm`` over the *input* channels,
    * a kernel-size-1 ``nn.Conv1d`` from ``inp_dim`` to ``out_dim`` channels,
    * the activation module returned by ``get_inplace_acti(activation)``.

    Args:
        inp_dim: number of input channels (cast to ``int``).
        out_dim: number of output channels (cast to ``int``). When
            ``activation == 'crelu'`` the convolution produces only half of
            this. Presumably the CReLU module concatenates two halves back to
            ``out_dim`` channels. NOTE(review): confirm against
            ``get_inplace_acti``. If so, an odd ``out_dim`` yields
            ``out_dim - 1`` channels after the activation.
        activation: activation name forwarded to ``get_inplace_acti``.
        norm: if True, prepend ``nn.GroupNorm(1, inp_dim)``.
        use_bias: whether the convolution has a bias term.

    Returns:
        list of ``nn.Module`` objects, ready to splice into an ``nn.Sequential``.
    """
    # Cast once so the conv and the norm layer agree on the channel count.
    # Previously GroupNorm received the raw value, which fails for
    # non-int inputs such as 64.0.
    inp_dim = int(inp_dim)
    if activation == 'crelu':
        out_dim = int(out_dim) // 2
    layers = [nn.Conv1d(inp_dim, int(out_dim), 1, bias=use_bias)]
    if norm:
        # Normalization is applied before the convolution (pre-norm).
        layers.insert(0, nn.GroupNorm(1, inp_dim))
    layers.append(get_inplace_acti(activation))
    return layers

class ConvNetwork(Network): # basic 1d conv network
    """Stack of pointwise (kernel size 1) ``nn.Conv1d`` stages.

    Channel sizes run ``object_dim -> *self.hs -> output_dim``. Every hidden
    stage uses ``args.activation`` and the final stage uses ``'none'``. Each
    stage comes from ``create_layers``, and all stages are wrapped in a single
    ``nn.Sequential`` stored as ``self.model``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.object_dim = args.object_dim
        # Channel count of the final conv; slightly different in meaning
        # from num_outputs.
        self.output_dim = args.output_dim
        self.is_crelu = args.activation == "crelu"
        if args.activation_final == "crelu":
            # CReLU is not used at the output; leaky ReLU is substituted.
            self.activation_final = get_inplace_acti("leakyrelu")

        channel_sizes = [self.object_dim] + self.hs + [self.output_dim]
        n_hidden = len(channel_sizes) - 2
        stage_activations = [args.activation] * n_hidden + ['none']

        stack = []
        for c_in, c_out, acti in zip(channel_sizes[:-1], channel_sizes[1:], stage_activations):
            stack.extend(
                create_layers(
                    c_in,
                    c_out,
                    activation=acti,
                    norm=self.use_layer_norm,
                    use_bias=args.use_bias,
                )
            )

        if args.dropout > 0:
            # Only a single dropout is supported: it goes right before the
            # last module (the final stage's activation), i.e. after the last conv.
            stack.insert(len(stack) - 1, nn.Dropout(args.dropout))

        self.model = nn.Sequential(*stack)
        self.train()
        self.reset_network_parameters()

    def forward(self, x):
        """Run the conv stack, then ``self.activation_final``.

        ``nn.Conv1d`` consumes ``(batch, channels, length)``, so dim 1 of
        ``x`` must equal ``object_dim``. The result has ``output_dim``
        channels. ``self.activation_final`` is set in ``__init__`` only for
        the crelu case. NOTE(review): otherwise it is presumably provided by
        ``Network``; confirm.
        """
        return self.activation_final(self.model(x))
